#include "topk.h"
#include "util.h"
#include <cstddef>
#include <iostream>

namespace wmsketch {

// Top-k tracker backed by a logistic model `lr_`.
//   k        - passed to the TopKFeatures base (presumably the number of tracked features).
//   dim, lr_init, l2_reg, no_bias - forwarded unchanged to the constructor of `lr_`.
LogisticTopK::LogisticTopK(uint32_t k, uint32_t dim, float lr_init, float l2_reg, bool no_bias)
    : TopKFeatures(k)
    , lr_(dim, lr_init, l2_reg, no_bias) {}

// Defaulted: members and the base class release their own resources.
LogisticTopK::~LogisticTopK() = default;

// Returns the binary prediction of `lr_` for the sparse example `x`.
// `x` holds (feature key, value) pairs (assumed from how update() uses `.first` as a heap key).
// The top-k heap is not touched.
bool LogisticTopK::predict(const std::vector<std::pair<uint32_t, float>>& x) {
  const bool predicted = lr_.predict(x);
  return predicted;
}

// One online training step on (x, label), followed by a heap refresh for the touched features.
//
// `lr_.update` fills `new_weights_`. The loop below reads `new_weights_[i]` for every
// entry of `x`, so it relies on one weight per entry of `x`.
// NOTE(review): that size contract comes from the indexing here; confirm it in lr_.update.
// Each feature key of `x` is then inserted into the top-k heap, or has its value
// changed there, using its freshly updated weight.
//
// Returns the prediction reported by `lr_.update`.
bool LogisticTopK::update(const std::vector<std::pair<uint32_t, float> >& x, bool label) {
  const bool yhat = lr_.update(new_weights_, x, label);
  // Unsigned index: the old `int` counter caused a signed/unsigned comparison with x.size().
  for (std::size_t i = 0; i < x.size(); ++i) {
    const uint32_t key = x[i].first;
    heap_.insert_or_change(key, new_weights_[i]);
  }
  return yhat;
}

// Forwards to the underlying model: returns `lr_.bias()`.
float LogisticTopK::bias() {
  const float b = lr_.bias();
  return b;
}

///////////////////////////////////////////////////////////////////////////////

// Top-k tracker whose weights live in the sketch-backed model `sk_`.
//   k - passed to the TopKFeatures base (presumably the number of tracked features).
//   log2_width, depth, seed, lr_init, l2_reg - forwarded unchanged to `sk_`.
// `t_` counts update() calls and starts at zero.
BlackBoxReductionTopK::BlackBoxReductionTopK(uint32_t k, uint32_t log2_width, uint32_t depth,
                                             int32_t seed, float lr_init, float l2_reg)
    : TopKFeatures(k)
    , sk_(log2_width, depth, seed, lr_init, l2_reg)
    , t_{0} {}

// Defaulted: members and the base class release their own resources.
BlackBoxReductionTopK::~BlackBoxReductionTopK() = default;

// Writes the current top-k (feature key, weight) pairs into `out`.
// Heap values are first re-read from `sk_` (see refresh_heap). Every weight returned
// by the base class is then multiplied by `sk_.scale()`.
void BlackBoxReductionTopK::topk(std::vector<std::pair<uint32_t, float>>& out) {
  refresh_heap();
  TopKFeatures::topk(out);
  const float scale = sk_.scale();
  for (auto it = out.begin(); it != out.end(); ++it) {
    it->second *= scale;
  }
}

// Returns the binary prediction of `sk_` for the sparse example `x`.
// The top-k heap is not touched.
bool BlackBoxReductionTopK::predict(const std::vector<std::pair<uint32_t, float>>& x) {
  const bool predicted = sk_.predict(x);
  return predicted;
}

// One online training step on (x, label), followed by a heap refresh for the touched features.
//
// `sk_.update` fills `new_weights_`. The loop below reads `new_weights_[i]` for every
// entry of `x`, so it relies on one weight per entry of `x`.
// NOTE(review): that size contract comes from the indexing here; confirm it in sk_.update.
// Each feature key of `x` is then inserted into (or updated in) the top-k heap.
// `t_` is incremented once per call.
//
// Returns the prediction reported by `sk_.update`.
bool BlackBoxReductionTopK::update(const std::vector<std::pair<uint32_t, float>>& x, bool label) {
  const bool yhat = sk_.update(new_weights_, x, label);
  // Unsigned index: the old `int` counter caused a signed/unsigned comparison with x.size().
  for (std::size_t i = 0; i < x.size(); ++i) {
    const uint32_t key = x[i].first;
    heap_.insert_or_change(key, new_weights_[i]);
  }
  ++t_;
  return yhat;
}

// Forwards to the underlying model: returns `sk_.bias()`.
float BlackBoxReductionTopK::bias() {
  const float b = sk_.bias();
  return b;
}

// Re-reads the weight of every key currently in the heap from `sk_`.
// Values stored at insert time can differ from `sk_.get(idx)` after later updates.
// `idxs_` is scratch storage that `heap_.keys` fills (presumably with the current heap keys).
void BlackBoxReductionTopK::refresh_heap() {
  heap_.keys(idxs_);
  for (auto it = idxs_.begin(); it != idxs_.end(); ++it) {
    const uint32_t idx = *it;
    heap_.change_val(idx, sk_.get(idx)); // estimated by CountSketch.
  }
}

///////////////////////////////////////////////////////////////////////////////

// Top-k tracker whose weights live in the sketch-backed model `sk_`.
//   k - passed to the TopKFeatures base (presumably the number of tracked features).
//   log2_width, depth, seed, lr_init, l2_reg - forwarded unchanged to `sk_`.
// `t_` counts update() calls and starts at zero.
JLRecoverySketchTopK::JLRecoverySketchTopK(uint32_t k, uint32_t log2_width, uint32_t depth,
                                           int32_t seed, float lr_init, float l2_reg)
    : TopKFeatures(k)
    , sk_(log2_width, depth, seed, lr_init, l2_reg)
    , t_{0} {}

// Defaulted: members and the base class release their own resources.
JLRecoverySketchTopK::~JLRecoverySketchTopK() = default;

// Writes the current top-k (feature key, weight) pairs into `out`.
// Heap values are first re-read from `sk_` (see refresh_heap). Every weight returned
// by the base class is then multiplied by `sk_.scale()`.
void JLRecoverySketchTopK::topk(std::vector<std::pair<uint32_t, float>>& out) {
  refresh_heap();
  TopKFeatures::topk(out);
  const float scale = sk_.scale();
  for (auto it = out.begin(); it != out.end(); ++it) {
    it->second *= scale;
  }
}

// Returns the binary prediction of `sk_` for the sparse example `x`.
// The top-k heap is not touched.
bool JLRecoverySketchTopK::predict(const std::vector<std::pair<uint32_t, float>>& x) {
  const bool predicted = sk_.predict(x);
  return predicted;
}

// One online training step on (x, label), followed by a heap refresh for the touched features.
//
// `sk_.update` fills `new_weights_`. The loop below reads `new_weights_[i]` for every
// entry of `x`, so it relies on one weight per entry of `x`.
// NOTE(review): that size contract comes from the indexing here; confirm it in sk_.update.
// Each feature key of `x` is then inserted into (or updated in) the top-k heap.
// `t_` is incremented once per call.
//
// Returns the prediction reported by `sk_.update`.
bool JLRecoverySketchTopK::update(const std::vector<std::pair<uint32_t, float>>& x, bool label) {
  const bool yhat = sk_.update(new_weights_, x, label);
  // Unsigned index: the old `int` counter caused a signed/unsigned comparison with x.size().
  for (std::size_t i = 0; i < x.size(); ++i) {
    const uint32_t key = x[i].first;
    heap_.insert_or_change(key, new_weights_[i]);
  }
  ++t_;
  return yhat;
}

// Forwards to the underlying model: returns `sk_.bias()`.
float JLRecoverySketchTopK::bias() {
  const float b = sk_.bias();
  return b;
}

// Re-reads the weight of every key currently in the heap from `sk_`.
// Values stored at insert time can differ from `sk_.get(idx)` after later updates.
// `idxs_` is scratch storage that `heap_.keys` fills (presumably with the current heap keys).
void JLRecoverySketchTopK::refresh_heap() {
  heap_.keys(idxs_);
  for (auto it = idxs_.begin(); it != idxs_.end(); ++it) {
    const uint32_t idx = *it;
    heap_.change_val(idx, sk_.get(idx));
  }
}

///////////////////////////////////////////////////////////////////////////////

// Top-k tracker whose weights live in the sketch-backed model `sk_`.
//   k - passed to the TopKFeatures base (presumably the number of tracked features).
//   log2_width, depth, seed, lr_init, l2_reg, median_update - forwarded unchanged to `sk_`.
// `t_` counts update() calls and starts at zero.
LogisticSketchTopK::LogisticSketchTopK(uint32_t k, uint32_t log2_width, uint32_t depth,
                                       int32_t seed, float lr_init, float l2_reg,
                                       bool median_update)
    : TopKFeatures(k)
    , sk_(log2_width, depth, seed, lr_init, l2_reg, median_update)
    , t_{0} {}

// Defaulted: members and the base class release their own resources.
LogisticSketchTopK::~LogisticSketchTopK() = default;

// Writes the current top-k (feature key, weight) pairs into `out`.
// Heap values are first re-read from `sk_` (see refresh_heap). Every weight returned
// by the base class is then multiplied by `sk_.scale()`.
void LogisticSketchTopK::topk(std::vector<std::pair<uint32_t, float>>& out) {
  refresh_heap();
  TopKFeatures::topk(out);
  const float scale = sk_.scale();
  for (auto it = out.begin(); it != out.end(); ++it) {
    it->second *= scale;
  }
}

// Returns the binary prediction of `sk_` for the sparse example `x`.
// The top-k heap is not touched.
bool LogisticSketchTopK::predict(const std::vector<std::pair<uint32_t, float>>& x) {
  const bool predicted = sk_.predict(x);
  return predicted;
}

// One online training step on (x, label), followed by a heap refresh for the touched features.
//
// `sk_.update` fills `new_weights_`. The loop below reads `new_weights_[i]` for every
// entry of `x`, so it relies on one weight per entry of `x`.
// NOTE(review): that size contract comes from the indexing here; confirm it in sk_.update.
// Each feature key of `x` is then inserted into (or updated in) the top-k heap.
// `t_` is incremented once per call.
//
// Returns the prediction reported by `sk_.update`.
bool LogisticSketchTopK::update(const std::vector<std::pair<uint32_t, float> >& x, bool label) {
  const bool yhat = sk_.update(new_weights_, x, label);
  // Unsigned index: the old `int` counter caused a signed/unsigned comparison with x.size().
  for (std::size_t i = 0; i < x.size(); ++i) {
    const uint32_t key = x[i].first;
    heap_.insert_or_change(key, new_weights_[i]);
  }
  ++t_;
  return yhat;
}

// Forwards to the underlying model: returns `sk_.bias()`.
float LogisticSketchTopK::bias() {
  const float b = sk_.bias();
  return b;
}

// Re-reads the weight of every key currently in the heap from `sk_`.
// Values stored at insert time can differ from `sk_.get(idx)` after later updates.
// `idxs_` is scratch storage that `heap_.keys` fills (presumably with the current heap keys).
void LogisticSketchTopK::refresh_heap() {
  heap_.keys(idxs_);
  for (auto it = idxs_.begin(); it != idxs_.end(); ++it) {
    const uint32_t idx = *it;
    heap_.change_val(idx, sk_.get(idx));
  }
}


///////////////////////////////////////////////////////////////////////////////

} // namespace wmsketch
